#  Copyright (c) 2023, Apple Inc. All rights reserved.
#
#  Use of this source code is governed by a BSD-3-clause license that can be
#  found in the LICENSE.txt file or at https://opensource.org/licenses/BSD-3-Clause

from typing import List, Set, Tuple

import numpy as np

from coremltools.converters.mil._deployment_compatibility import AvailableTarget
from coremltools.converters.mil.frontend import _utils
from coremltools.converters.mil.mil import Block
from coremltools.converters.mil.mil import Builder as mb
from coremltools.converters.mil.mil import Operation, Var, types
from coremltools.converters.mil.mil.block import is_current_opset_version_compatible_with
from coremltools.converters.mil.mil.passes.graph_pass import AbstractGraphPass
from coremltools.converters.mil.mil.passes.helper import (
    _check_child_op_type,
    _check_no_output_connection,
    block_context_manager,
)
from coremltools.converters.mil.mil.passes.pass_registry import register_pass
from coremltools.optimize import _utils as optimize_utils


@register_pass(namespace="common")
class merge_affine_dequantize_with_consecutive_ops(AbstractGraphPass):
    """
    Constant-fold a chain of supported ops that starts at a
    ``constexpr_affine_dequantize`` op by applying the same ops directly to the
    quantized data. Tensor-wise quantization supports more op types than
    channel-wise quantization, which only supports a subset. For example

    .. code-block::

        Input graph:
            data -> constexpr_affine_dequantize -> transpose -> expand_dims -> out

        Output graph:
            new_data -> constexpr_affine_dequantize -> out

    where ``new_data`` is computed by ``data -> transpose -> expand_dims``.

    Only a single linked-list pattern is folded. For example, the following
    pattern is left unchanged

    .. code-block::

              |-> constexpr_affine_dequantize -> transpose -> out
        data -|
              |-> constexpr_affine_dequantize -> reshape -> out_2

    because the quantized data is consumed by more than one ``constexpr`` op.
    """

    # op types that can be folded when scale and zero point are both scalars
    SUPPORTED_OP_TYPES_PER_TENSOR = {
        "transpose",
        "reshape",
        "expand_dims",
        "squeeze",
    }
    # op types that can be folded when scale / zero point are not both scalars
    SUPPORTED_OP_TYPES_PER_CHANNEL = {"transpose"}
    assert (
        SUPPORTED_OP_TYPES_PER_CHANNEL <= SUPPORTED_OP_TYPES_PER_TENSOR
    ), "If an op can merge with channel-wise quantization, then it must also be able to merge with tensor-wise quantization"

    def apply(self, prog):
        """Run the pass on every function of ``prog`` until nothing changes."""
        for function in prog.functions.values():
            # sweep again whenever the previous sweep folded something
            while self.merge_affine_dequantize_with_consecutive_ops_block(function):
                pass

    @block_context_manager
    def merge_affine_dequantize_with_consecutive_ops_block(self, block: Block):
        """
        Sweep ``block`` once (nested blocks are swept until stable) and return
        whether any fold happened directly in ``block``.
        """
        any_fused = False
        for candidate in list(block.operations):
            # ops detached by an earlier fold in this sweep have no enclosing block
            if candidate.enclosing_block is None:
                continue

            for nested in candidate.blocks:
                while self.merge_affine_dequantize_with_consecutive_ops_block(nested):
                    pass

            if candidate.op_type == "constexpr_affine_dequantize" and self._try_to_transform(
                candidate, block
            ):
                any_fused = True
        return any_fused

    @staticmethod
    def _apply_equivalent_transform(val: np.ndarray, op: Operation) -> np.ndarray:
        """
        Apply to ``val`` the numpy counterpart of ``op``.

        Raises ``ValueError`` when ``op.op_type`` is not one of
        ``SUPPORTED_OP_TYPES_PER_TENSOR``.
        """
        kind = op.op_type
        if kind not in merge_affine_dequantize_with_consecutive_ops.SUPPORTED_OP_TYPES_PER_TENSOR:
            raise ValueError(f"unsupported op_type {op.op_type}")

        if kind == "transpose":
            return np.transpose(val, axes=op.perm.val)
        elif kind == "reshape":
            # the op's inferred output shape is used as the target shape
            return np.reshape(val, op.outputs[0].shape)
        elif kind == "expand_dims":
            return np.expand_dims(val, axis=op.axes.val.tolist())
        elif kind == "squeeze":
            squeeze_axes = op.axes
            if squeeze_axes is not None and squeeze_axes.val is not None:
                return np.squeeze(val, axis=tuple(op.axes.val.tolist()))
            # no axes given: drop every size-1 dimension
            return np.squeeze(val)

    @staticmethod
    def search_for_ops_to_fold(
        op: Operation, block: Block, supported_op_types: Set[str]
    ) -> List[Operation]:
        """
        Walk downstream from ``op`` and collect the chain of consecutive ops
        whose type is in ``supported_op_types``. The walk stops at a var that
        is an output of ``block`` or when ``_check_child_op_type`` accepts none
        of the supported types.
        """
        chain = []
        tail = op
        while tail.outputs[0] not in block.outputs:
            if not any(
                _check_child_op_type(tail, supported_op_type)
                for supported_op_type in supported_op_types
            ):
                break
            tail = tail.outputs[0].child_ops[0]
            chain.append(tail)
        return chain

    @staticmethod
    def _try_to_transform_per_tensor(op: Operation, block: Block) -> bool:
        """Fold the downstream chain of ``op`` when scale and zero point are scalars."""
        this_pass = merge_affine_dequantize_with_consecutive_ops
        assert (
            op.scale.rank == 0 and op.zero_point.rank == 0
        ), "The _try_to_transform_per_tensor method should only be used for per-tensor dequantization case"

        chain = this_pass.search_for_ops_to_fold(
            op, block, this_pass.SUPPORTED_OP_TYPES_PER_TENSOR
        )
        if not chain:
            return False

        # replay the chain on the quantized source data
        folded = op.quantized_data.val
        for link in chain:
            folded = this_pass._apply_equivalent_transform(folded, link)

        # build the replacement op under the name of the chain's final output,
        # then redirect the consumers of that output and drop the old ops
        last = chain[-1]
        replacement = _utils._construct_constexpr_dequant_op(
            folded,
            op.zero_point,
            op.scale,
            op.axis,
            name=last.outputs[0].name,
            before_op=last,
        )
        block.replace_uses_of_var_after_op(
            anchor_op=last,
            old_var=last.outputs[0],
            new_var=replacement,
            force_replace=True,
        )
        block.remove_ops([op] + chain)
        return True

    @staticmethod
    def _try_to_transform_per_channel(op: Operation, block: Block) -> bool:
        """Fold the downstream chain of ``op`` for the non-scalar scale / zero point case."""
        this_pass = merge_affine_dequantize_with_consecutive_ops
        scale = op.scale
        zero_point = op.zero_point
        # make the axis non-negative so it can be looked up in a permutation
        raw_axis = op.axis.val
        axis = raw_axis if raw_axis >= 0 else raw_axis + op.quantized_data.rank

        chain = this_pass.search_for_ops_to_fold(
            op, block, this_pass.SUPPORTED_OP_TYPES_PER_CHANNEL
        )
        if not chain:
            return False

        # replay the chain on the quantized source data, tracking where the
        # quantization axis ends up after each transpose
        folded = op.quantized_data.val
        for link in chain:
            folded = this_pass._apply_equivalent_transform(folded, link)
            if link.op_type == "transpose":
                axis = np.where(link.perm.val == axis)[0][0]

        # build the replacement op under the name of the chain's final output,
        # then redirect the consumers of that output and drop the old ops
        last = chain[-1]
        replacement = mb.constexpr_affine_dequantize(
            quantized_data=folded,
            zero_point=zero_point,
            scale=scale,
            axis=axis,
            name=last.outputs[0].name,
            before_op=last,
        )
        block.replace_uses_of_var_after_op(
            anchor_op=last,
            old_var=last.outputs[0],
            new_var=replacement,
            force_replace=True,
        )
        block.remove_ops([op] + chain)
        return True

    def _try_to_transform(self, op: Operation, block: Block) -> bool:
        """Dispatch ``op`` to the per-tensor or per-channel folding routine."""
        # the quantized data must feed exactly one op
        if len(op.quantized_data.child_ops) != 1:
            return False

        is_per_tensor = op.scale.rank == 0 and op.zero_point.rank == 0
        handler = (
            self._try_to_transform_per_tensor
            if is_per_tensor
            else self._try_to_transform_per_channel
        )
        return handler(op, block)


@register_pass(namespace="common")
class int_op_canonicalization(AbstractGraphPass):
    """
    In Core ML, a general quantized operator is represented as
    ``dequantize -> the floating-point version of this operator -> quantize``,
    because mathematically the operator acts on the floating-point tensor
    rather than on its quantized integer representation.

    Some quantized operators involve no floating-point arithmetic, so the
    leading ``dequantize`` and the trailing ``quantize`` are unnecessary.
    Examples are:

    * reshape
    """

    # op type -> opset versions in which the op is treated as an int op here
    INT_OP_TYPES_AND_OPSET_VERSIONS = {"reshape": {AvailableTarget.iOS17}}

    def apply(self, prog):
        """Run the pass on every function of ``prog``."""
        for function in prog.functions.values():
            self._canonicalize_int_ops_block(function)

    @block_context_manager
    def _canonicalize_int_ops_block(self, block: Block):
        """Canonicalize int ops in ``block`` and its nested blocks until stable."""

        def rewrite_first_match(current: Block) -> bool:
            for candidate in list(current.operations):
                for nested in candidate.blocks:
                    self._canonicalize_int_ops_block(nested)

                matched = self.match_pattern(candidate)
                if matched is None:
                    continue

                dequantize, quantize = matched
                # stop this sweep after a rewrite: the snapshot being iterated
                # no longer reflects the block
                if self.try_to_transform(dequantize, candidate, quantize):
                    return True

            return False

        while rewrite_first_match(block):
            pass

    def match_pattern(self, op: Operation) -> Tuple[Operation, Operation]:
        """
        Return ``(dequantize, quantize)`` when ``op`` is a supported int op fed
        by a ``dequantize`` and followed by a ``quantize``; otherwise ``None``.
        """
        supported_versions = self.INT_OP_TYPES_AND_OPSET_VERSIONS.get(op.op_type)
        if supported_versions is None or op.opset_version not in supported_versions:
            return None

        # the input has to come from a dequantize
        dequantize = op.x.op
        if dequantize is None or dequantize.op_type != "dequantize":
            return None

        # the output has to go to a quantize
        if not _check_child_op_type(op, "quantize"):
            return None

        # block outputs need no extra check here:
        # * dequantize may connect to a block output, `try_to_transform`
        #   leaves dequantize in place
        # * the child-op check above is relied on to ensure op feeds the
        #   quantize only, i.e. op itself is not a block output
        return dequantize, op.outputs[0].child_ops[0]

    def try_to_transform(self, dequantize: Operation, op: Operation, quantize: Operation) -> bool:
        """
        Replace the output of ``quantize`` with an int op applied directly to
        the input of ``dequantize``. Return whether the replacement succeeded.
        """
        block: Block = op.enclosing_block

        int_var = self.build_int_op(dequantize, op, quantize)
        replaced = block.try_replace_uses_of_var_after_op(
            anchor_op=quantize,
            old_var=quantize.outputs[0],
            new_var=int_var,
        )
        if not replaced:
            return False

        # op and quantize are removed, dequantize is not:
        # * every use of quantize's output now points at the int op
        # * dequantize may still feed other ops; if it does not, the
        #   dead_code_elimination pass will clean it up
        block.remove_ops([op, quantize])

        return True

    @staticmethod
    def build_int_op(dequantize: Operation, op: Operation, quantize: Operation) -> Var:
        """
        Build the int counterpart of ``op`` before ``op``, reading the input of
        ``dequantize`` and named after the output of ``quantize``.

        Raises ``NotImplementedError`` for op types other than ``reshape``.
        """
        if op.op_type != "reshape":
            raise NotImplementedError(f"no build method implemented for int op {op.op_type}")

        return mb.reshape(
            x=dequantize.input,
            shape=op.shape,
            name=quantize.outputs[0].name,
            before_op=op,
        )


# TODO (rdar://107718371): remove this pass after implementing QuantizedVar
@register_pass(namespace="common")
class nullify_redundant_quantization_zero_point(AbstractGraphPass):
    """
    Core ML quantization performs better when ``zero point = 0``,
    so this pass makes ``zero point = 0`` wherever it can:

    * ``zero point = -128``
        * this must be an int8 quantization
        * equivalent to uint8 quantization with 0 zero point
    * ``zero point = 128``
        * this must be an uint8 quantization
        * equivalent to int8 quantization with 0 zero point

    Because ``zero point = 0`` is equivalent to ``zero point = None`` in Core ML
    semantics, the pass goes one step further and canonicalizes to
    ``zero point = None`` in order to:

    * make further graph passes easier
    * avoid serializing trivial 0

    The ``zero point = 0`` case is canonicalized trivially

    .. code-block::

        Input op:

            quantize/dequantize(zero_point=0)

        Output op:

            quantize/dequantize(zero_point=None)

    In the ``zero point = ±128`` cases the output must be conserved despite
    the zero-point shift, so only the following are transformed:

    * const dequantize, where the zero-point shift is fused into the const

    .. code-block::

        Input op:

            dequantize(input=const, zero_point=±128)

        Output op:

            dequantize(input=const∓128, zero_point=None)

    * ``quantize -> dequantize``, where both are nullified simultaneously

    .. code-block::

        Input graph:

            input -> quantize(zero_point=±128) -> dequantize(zero_point=±128) -> output

        Output graph:

            input -> quantize(zero_point=None) -> dequantize(zero_point=None) -> output
    """

    def apply(self, prog):
        """Run the pass on every function of ``prog``."""
        for function in prog.functions.values():
            self._nullify_redundant_quantization_zero_point_block(function)

    @block_context_manager
    def _nullify_redundant_quantization_zero_point_block(self, block: Block):
        """Nullify zero points in ``block`` and its nested blocks until stable."""

        def sweep(current: Block) -> bool:
            pair_rewritten = False
            for candidate in list(current.operations):
                # ops removed earlier in this sweep have no enclosing block
                if candidate.enclosing_block is None:
                    continue

                for nested in candidate.blocks:
                    self._nullify_redundant_quantization_zero_point_block(nested)

                # these two only replace the current op, so their results do
                # not need to trigger another sweep
                self.try_transform_zp0(candidate)
                self.try_transform_zp128_const_dequantize(candidate)

                # this one also replaces a downstream op, so request another
                # sweep over a fresh snapshot of the block
                if self.try_transform_zp128_quantize_dequantize(candidate):
                    pair_rewritten = True

            return pair_rewritten

        while sweep(block):
            pass

    @staticmethod
    def try_transform_zp0(op: Operation) -> bool:
        """
        Rebuild a ``quantize`` / ``dequantize`` whose zero point is all zeros
        without the zero point. Return whether ``op`` was replaced.
        """
        if op.op_type not in ("quantize", "dequantize"):
            return False

        zp_var = op.zero_point
        # nothing to nullify when there is no zero point already
        if zp_var is None:
            return False
        if not np.all(zp_var.val == 0):
            return False

        shared_kwargs = dict(input=op.input, scale=op.scale, axis=op.axis, before_op=op)
        new_var: Var
        if op.op_type == "quantize":
            new_var = mb.quantize(output_dtype=op.output_dtype, **shared_kwargs)
        else:
            new_var = mb.dequantize(**shared_kwargs)

        block: Block = op.enclosing_block
        replaced = block.try_replace_uses_of_var_after_op(
            anchor_op=op, old_var=op.outputs[0], new_var=new_var
        )
        if not replaced:
            return False
        block.remove_ops([op])

        return True

    @staticmethod
    def try_transform_zp128_const_dequantize(op: Operation) -> bool:
        """
        For a ``dequantize`` with a constant input and a zero point of ±128,
        fold the shift into the constant and drop the zero point.
        Return whether ``op`` was replaced.
        """
        if op.op_type != "dequantize":
            return False

        zp_var = op.zero_point
        # nothing to nullify when there is no zero point already
        if zp_var is None:
            return False
        zp_val = zp_var.val

        is_negative_128 = np.all(zp_val == -128)
        is_positive_128 = np.all(zp_val == 128)
        if not (is_negative_128 or is_positive_128):
            return False

        data = op.input.val
        # only a constant input can absorb the shift
        if data is None:
            return False
        # widen to int16 so the ±128 shift cannot wrap before the final cast
        widened = np.int16(data)
        if is_negative_128:
            shifted = np.uint8(widened + 128)
        else:
            shifted = np.int8(widened - 128)

        new_var = mb.dequantize(
            input=shifted,
            scale=op.scale,
            axis=op.axis,
            before_op=op,
        )

        block: Block = op.enclosing_block
        replaced = block.try_replace_uses_of_var_after_op(
            anchor_op=op, old_var=op.outputs[0], new_var=new_var
        )
        if not replaced:
            return False
        block.remove_ops([op])

        return True

    @staticmethod
    def try_transform_zp128_quantize_dequantize(op: Operation) -> bool:
        """
        For a ``quantize -> dequantize`` pair that share a zero point of ±128,
        rebuild both without zero point and with the quantized dtype flipped.
        Return whether the pair was replaced.
        """
        if op.op_type != "quantize":
            return False

        zp_var = op.zero_point
        # nothing to nullify when there is no zero point already
        if zp_var is None:
            return False
        zp_val = zp_var.val

        is_negative_128 = np.all(zp_val == -128)
        is_positive_128 = np.all(zp_val == 128)
        if not (is_negative_128 or is_positive_128):
            return False

        if not _check_child_op_type(op, "dequantize"):
            return False
        dequantize_op = op.outputs[0].child_ops[0]

        dequantize_zp_var = dequantize_op.zero_point
        if dequantize_zp_var is None:
            return False

        # the dequantize must carry the same ±128 zero point as the quantize
        expected_zp = -128 if is_negative_128 else 128
        if not np.all(dequantize_zp_var.val == expected_zp):
            return False

        new_quantize = mb.quantize(
            input=op.input,
            scale=op.scale,
            axis=op.axis,
            output_dtype="uint8" if is_negative_128 else "int8",
            before_op=dequantize_op,
        )
        new_dequantize = mb.dequantize(
            input=new_quantize,
            scale=dequantize_op.scale,
            axis=dequantize_op.axis,
            before_op=dequantize_op,
        )

        block: Block = op.enclosing_block
        replaced = block.try_replace_uses_of_var_after_op(
            anchor_op=dequantize_op,
            old_var=dequantize_op.outputs[0],
            new_var=new_dequantize,
        )
        if not replaced:
            return False
        block.remove_ops([op, dequantize_op])
        return True


@register_pass(namespace="common")
class dequantize_quantize_pair_elimination(AbstractGraphPass):
    """
    When a ``dequantize`` is followed by an identical ``quantize`` (same scale,
    zero point, axis), they cancel out and can be eliminated

    .. code-block::

        Input graph:
            input -> dequantize -> quantize -> output

        Output graph:
            input -> output

    When the pattern has branches (dequantize has multiple children), we cannot
    eliminate the whole pair, but can still shorten the path. More specifically:

    .. code-block::

        Input graph:
            op1 -> dequantize -> quantize -> op2
                         |
                         |-> some_other_op

        Output graph:
            op1 -> dequantize -> some_other_op
             |
             |-> op2

    PS: On the other hand, the reversed pattern, i.e., ``quantize -> dequantize``,
    is not redundant, since that is the pattern which naturally occurs when a
    quantized op is converted.
    In current activation quantization conversion, a quantized op becomes

    .. code-block::

        dequantize -> regular op -> quantize

    so if we have a sequence of quantized ops, we will get

    .. code-block::

        dequantize -> regular op1 -> quantize -> dequantize -> regular op2 -> quantize

    The ``quantize -> dequantize`` pair in the middle is not redundant, even if
    they have identical scales and zero points and axes, since removing them will lead to
    loss of information about the quantization parameters of the output var of op1
    """

    def apply(self, prog):
        """Run the pass on every function of ``prog``."""
        for f in prog.functions.values():
            self._dequantize_quantize_pair_elimination_block(f)

    @block_context_manager
    def _dequantize_quantize_pair_elimination_block(self, block):
        """Eliminate pairs in ``block`` and its nested blocks until stable."""

        def apply_block(block: Block) -> bool:
            fusion_occurred = False
            for op in list(block.operations):
                # ops removed earlier in this sweep have no enclosing block
                if op.enclosing_block is None:
                    continue

                for b in op.blocks:
                    self._dequantize_quantize_pair_elimination_block(b)

                # a successful elimination removes downstream ops, so request
                # another sweep over a fresh snapshot of the block
                if self.try_dequantize_quantize_pair_elimination(op):
                    fusion_occurred = True
            return fusion_occurred

        need_transformation = True
        while need_transformation:
            need_transformation = apply_block(block)

    @staticmethod
    def try_dequantize_quantize_pair_elimination(op: Operation) -> bool:
        """
        If ``op`` is a ``dequantize``, bypass every child ``quantize`` that has
        the same scale, zero point and axis by rewiring the users of that
        ``quantize`` to the input of ``op``. ``op`` itself is removed when no
        child is left afterwards.

        Returns whether at least one ``quantize`` was removed.
        """

        def _check_quantize_removable(quantize_op: Operation) -> bool:
            # scale must match
            if np.any(op.scale.val != quantize_op.scale.val):
                return False

            # zero point must be absent on both, or present and equal on both
            is_dequantize_zp_present = op.zero_point is not None
            is_quantize_zp_present = quantize_op.zero_point is not None
            if is_dequantize_zp_present != is_quantize_zp_present:
                return False
            if is_dequantize_zp_present and is_quantize_zp_present:
                if np.any(op.zero_point.val != quantize_op.zero_point.val):
                    return False

            # axis must be absent on both, or present and equal on both
            is_dequantize_axis_present = op.axis is not None
            is_quantize_axis_present = quantize_op.axis is not None
            if is_dequantize_axis_present != is_quantize_axis_present:
                return False
            if is_dequantize_axis_present and is_quantize_axis_present:
                if op.axis.val != quantize_op.axis.val:
                    return False

            return True

        if op.op_type != "dequantize":
            return False

        block: Block = op.enclosing_block
        # a dequantize whose output is a block output must be kept
        if op.outputs[0] in block.outputs:
            return False

        any_quantize_removed: bool = False
        # Iterate over a snapshot: removing a child op inside the loop may
        # shrink ``child_ops``, which would make a live iteration skip siblings.
        for child_op in list(op.outputs[0].child_ops):
            # already detached earlier in this loop (e.g. listed more than once)
            if child_op.enclosing_block is None:
                continue
            if child_op.op_type == "quantize" and _check_quantize_removable(child_op):
                if block.try_replace_uses_of_var_after_op(
                    anchor_op=child_op,
                    old_var=child_op.outputs[0],
                    new_var=op.input,
                ):
                    block.remove_ops([child_op])
                    any_quantize_removed = True
        if any_quantize_removed and len(op.outputs[0].child_ops) == 0:
            # Remove the dequant op if all its children quantize ops got removed.
            block.remove_ops([op])
        return any_quantize_removed


@register_pass(namespace="common")
class distributive_quantized_binary_op_scale_normalization(AbstractGraphPass):
    """
    In the backend, for better performance, quantized op can have 1 input scale
    fused within the quantized op kernel. For binary ops, there are 2 inputs,
    but only 1 can get fused. For example, for quantized ``add``

    .. code-block::

        MIL graph (consists of MIL ops):

            dequantize(x, s_x, zp_x) -|
            x_fp = (x - zp_x) * s_x   |
                                      |->  add(x_fp, y_fp)   -> quantize(z_fp, s_z, zp_z)
            dequantize(y, s_y, zp_y) -|   z_fp = x_fp + y_fp      z = z_fp / s_z + zp_z
            y_fp = (y - zp_y) * s_y

        Backend graph (consists of backend instructions, usually including + - * / and fused *+):

            x_shift = x - zp_x -------------------------|
                                                        |-> z_fp = s_x * x_shift + y_fp -> z = z_fp / s_z + zp_z
            y_shift = y - zp_y -> y_fp = s_y * y_shift -|

    Where ``x`` and ``y`` are the inputs, ``z`` is the output,
    ``s`` and ``zp`` are the corresponding scale and zero point.

    The reason why fusing one scale leads to better performance is,
    instead of 2 instructions ``x_fp = s_x * x_shift`` and ``z_fp = x_fp + y_fp``,
    a single ``z_fp = x_shift * s_x + y_fp`` instruction achieves the same result.

    In this pass, we normalize ``s_y`` to 1, so the ``y_fp = s_y * y_shift``
    instruction can get skipped as well, leading to even better performance.
    This pass only applies to distributive binary ops such as ``add`` and ``sub``

    Appendix: Mathematical and Computer-Scientific Details

    Mathematically, for a binary operator ``.op.``

    .. code-block::

        z_fp = (x - zp_x) * s_x .op. (y - zp_y) * s_y
             = s_y * [(x - zp_x) * s_x/s_y .op. (y - zp_y) * 1]

    The corresponding pseudo code is

    .. code-block::

        # before
        z_fp = (x - zp_x) * s_x .op. (y - zp_y) * s_y
        z = z_fp / s_z + zp_z

        # after
        z_fp_modified = (x - zp_x) * s_x/s_y .op. (y - zp_y) * 1.0
        z = z_fp_modified / (s_z/s_y) + zp_z

    Concretely, as a MIL graph pass

    .. code-block::

        Input graph:
            dequantize(scale=s_x) -|
                                   |-> op -> quantize(scale=s_z)
            dequantize(scale=s_y) -|

        Output graph:
            dequantize(scale=s_x/s_y) -|
                                       |-> op -> quantize(scale=s_z/s_y)
            dequantize(scale=1.0)     -|

    PS: we only support scalar ``s_y`` for now. If ``s_y`` is not scalar but
    ``s_x`` is, we would swap ``x`` and ``y``. Support for both-vector case is
    to be explored, due to the broadcasting complication.
    """

    DISTRIBUTIVE_BINARY_OPS = {"add", "sub"}

    def apply(self, prog):
        """Run the pass on every function of ``prog``, including nested blocks."""

        @block_context_manager
        def apply_block(block: Block):
            for op in list(block.operations):
                for b in op.blocks:
                    apply_block(b)

                matched_ops = self.match_pattern(op)
                if matched_ops is not None:
                    dequantize_x, dequantize_y, quantize_z = matched_ops
                    self.try_to_transform(op, dequantize_x, dequantize_y, quantize_z)

        for f in prog.functions.values():
            apply_block(f)

    def match_pattern(self, op: Operation) -> Tuple[Operation, Operation, Operation]:
        """
        try to match distributive quantized binary op:
                ...
                 ^
                 |
            dequantize(x) -|
                           |-> op(x, y) (-> relu) -> quantize(z)
            dequantize(y) -|
                 |
                 v
                ...

        return dequantize_x, dequantize_y, quantize_z for further transformation

        return None if no match
        """
        # make sure the op is distributive
        if op.op_type not in self.DISTRIBUTIVE_BINARY_OPS:
            return None

        # quantized op may be fused with relu
        # relu would not affect distributivity
        tail_op = op
        if _check_child_op_type(op, "relu"):
            tail_op = op.outputs[0].child_ops[0]

        # make sure the inputs are quantized
        dequantize_x = op.x.op
        dequantize_y = op.y.op
        if (
            dequantize_x is None
            or dequantize_y is None
            or dequantize_x.op_type != "dequantize"
            or dequantize_y.op_type != "dequantize"
        ):
            return None

        # make sure the output is quantized
        if not _check_child_op_type(tail_op, "quantize"):
            return None
        quantize_z = tail_op.outputs[0].child_ops[0]

        # make sure the intermediate results are not block outputs
        # since we only guarantee conservation of z
        # (quantize_z is deliberately not checked: z is the conserved result,
        # so it is allowed to be a block output)
        if not _check_no_output_connection(
            op.enclosing_block, [dequantize_x, dequantize_y, op, tail_op]
        ):
            return None

        return dequantize_x, dequantize_y, quantize_z

    def try_to_transform(
        self, op: Operation, dequantize_x: Operation, dequantize_y: Operation, quantize_z: Operation
    ) -> bool:
        """
        given dequantize_x, dequantize_y, quantize_z, transform by
            z_fp = (x - zp_x) * s_x/s_y .op. (y - zp_y) * 1.0
            z = z_fp / (s_z/s_y) + zp_z

        See the class doc for details

        return True if the graph got transformed, else False
        """
        block = quantize_z.enclosing_block

        new_s_x, new_s_z = self.try_to_divide(dequantize_x, dequantize_y, quantize_z)
        # if s_y cannot be used to divide, then swap x and y and try again
        if new_s_x is None and new_s_z is None:
            dequantize_x, dequantize_y = dequantize_y, dequantize_x
            new_s_x, new_s_z = self.try_to_divide(dequantize_x, dequantize_y, quantize_z)
            # after swap, if still cannot divide, then give up
            if new_s_x is None and new_s_z is None:
                return False

        def convert_mil_float_dtype_to_np(mil_dtype):
            # anything that is not recognized as fp16 is treated as fp32
            if mil_dtype == types.fp16 or mil_dtype == "float16":
                np_dtype = np.float16
            else:
                np_dtype = np.float32
            return np_dtype

        # each new scale keeps the float precision of the scale it replaces
        new_s_x_dtype = convert_mil_float_dtype_to_np(dequantize_x.scale.val.dtype)
        new_s_y_dtype = convert_mil_float_dtype_to_np(dequantize_y.scale.val.dtype)
        new_s_z_dtype = convert_mil_float_dtype_to_np(quantize_z.scale.val.dtype)

        # insert normalized new_dequantize_x and new_dequantize_y before op
        new_dequantize_x = mb.dequantize(
            input=dequantize_x.input,
            scale=new_s_x_dtype(new_s_x),
            zero_point=dequantize_x.zero_point,
            axis=dequantize_x.axis,
            before_op=op,
        )
        new_dequantize_y = mb.dequantize(
            input=dequantize_y.input,
            scale=new_s_y_dtype(1)
            if dequantize_y.axis is None
            else np.full(dequantize_y.scale.val.shape, 1.0),
            zero_point=dequantize_y.zero_point,
            axis=dequantize_y.axis,
            before_op=op,
        )

        # insert normalized new_quantize_z before quantize_z
        new_quantize_z = mb.quantize(
            input=quantize_z.input,
            scale=new_s_z_dtype(new_s_z),
            zero_point=quantize_z.zero_point,
            axis=quantize_z.axis,
            output_dtype=quantize_z.output_dtype,
            before_op=quantize_z,
        )
        if not (
            # replace dequantize_x and dequantize_y with the normalized ones
            # in the range of (new_dequantize_x, op] and (new_dequantize_y, op]
            # in case dequantize_x and dequantize_y also feed to other ops
            # which should not get altered by this transformation
            block.try_replace_uses_of_var_after_op(
                anchor_op=new_dequantize_x.op,
                end_op=op,
                old_var=dequantize_x.outputs[0],
                new_var=new_dequantize_x,
            )
            and block.try_replace_uses_of_var_after_op(
                anchor_op=new_dequantize_y.op,
                end_op=op,
                old_var=dequantize_y.outputs[0],
                new_var=new_dequantize_y,
            )
            # replace quantize_z with the normalized one
            and block.try_replace_uses_of_var_after_op(
                anchor_op=quantize_z, old_var=quantize_z.outputs[0], new_var=new_quantize_z
            )
        ):
            return False

        # remove quantize_z here, but not dequantize_x and dequantize_y, since:
        # * all uses of quantize_z has been replaced with the normalized one
        # * dequantize_x and dequantize_y may feed to multiple ops, which are not replaced
        #   (if not, then pass dead_code_elimination will eliminate them)
        block.remove_ops([quantize_z])

        return True

    def try_to_divide(
        self,
        dequantize_x: Operation,
        dequantize_y: Operation,
        quantize_z: Operation,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """
        compute s_x/s_y and s_z/s_y, return the results if succeeds, else (None, None)

        The broadcast rule is very complicated:
        1. Broadcast s_x to x, s_y to y, s_z to z, according to axes
        2. Broadcast s_x and s_y
        3. Perform s_x/s_y and s_z/s_y
        4. De-broadcast s_x/s_y and s_z/s_y down to vectors according to axes,
           raise exception if impossible to de-broadcast

        As a result, for now we only handle the scalar s_y case
        """

        # TODO (rdar://109170887): explore vector s_y
        if dequantize_y.axis is not None:
            return None, None

        # do the division in fp32 regardless of the stored scale precision
        s_x_fp32 = np.float32(dequantize_x.scale.val)
        s_y_fp32 = np.float32(dequantize_y.scale.val)
        s_z_fp32 = np.float32(quantize_z.scale.val)

        s_x_d_s_y = s_x_fp32 / s_y_fp32
        s_z_d_s_y = s_z_fp32 / s_y_fp32

        # give up if either quotient falls outside what the fp16 checks accept
        if (
            self.overflow_fp16(s_x_d_s_y)
            or self.underflow_fp16(s_x_d_s_y)
            or self.overflow_fp16(s_z_d_s_y)
            or self.underflow_fp16(s_z_d_s_y)
        ):
            return None, None

        return s_x_d_s_y, s_z_d_s_y

    @staticmethod
    def overflow_fp16(x: np.ndarray) -> bool:
        """Return whether the largest magnitude in ``x`` exceeds 65504."""
        return np.max(np.abs(x)) > 65504

    @staticmethod
    def underflow_fp16(x: np.ndarray) -> bool:
        """
        Return whether the smallest magnitude in ``x`` is below the first
        float16 value after 0.0 in the direction of 1.0.
        """
        return np.min(np.abs(x)) < np.nextafter(0.0, 1.0, dtype=np.float16)


@register_pass(namespace="common")
class dequantize_to_constexpr(AbstractGraphPass):
    """
    Canonicalization pass that rewrites every ``dequantize`` op whose value can be
    materialized into the equivalent ``constexpr_affine_dequantize`` op.

    .. code-block::

        Input graph:

            dequantize(input=const) -> downstream op

        Output graph:

            constexpr_affine_dequantize -> downstream op

    The motivation is serialization size: a constant tensor propagated through
    a ``dequantize`` op would be serialized in bloated/decompressed fashion,
    whereas with ``constexpr_affine_dequantize``
    the constant weights/tensors remain compressed at serialization.
    """

    def apply(self, prog):
        @block_context_manager
        def apply_block(block):
            # Walk a snapshot of the op list, because transform_op removes ops from the block.
            for op in list(block.operations):
                for nested_block in op.blocks:
                    apply_block(nested_block)

                if not self.is_valid_op(op):
                    continue
                self.transform_op(op)

        for function in prog.functions.values():
            apply_block(function)

    def is_valid_op(self, op):
        """Only a ``dequantize`` op whose value can be materialized is transformed."""
        if op.op_type != "dequantize":
            return False
        return op.can_materialize_val()

    def transform_op(self, op):
        """Replace ``op`` by a constexpr dequantize op built from the same data / scale / zero point / axis."""
        if op.zero_point is not None:
            zero_point = op.zero_point.val
        elif op.input.dtype == types.int8:
            # No zero point given: fall back to a zero with int8 dtype for int8 input...
            zero_point = np.int8(0)
        else:
            # ...and to a uint8 zero for any other input dtype.
            zero_point = np.uint8(0)

        axis = op.axis.val if op.axis is not None else None

        constexpr_var = _utils._construct_constexpr_dequant_op(
            op.input.val,
            zero_point,
            op.scale.val,
            axis,
            name=f"{op.name}_affine_dequantized",
            before_op=op,
        )

        # Reroute the consumers of the dequantize output to the new var, then drop the old op.
        enclosing_block = op.enclosing_block
        enclosing_block.replace_uses_of_var_after_op(
            anchor_op=op, old_var=op.outputs[0], new_var=constexpr_var
        )
        enclosing_block.remove_ops([op])


@register_pass(namespace="common")
class reorder_lut_per_channel_scale(AbstractGraphPass):
    """
    The lut with per-channel-scale was represented as the following op combinations:
        weight = constexpr_lut_to_dense()
        weight = constexpr_blockwise_shift_scale(weight)
        output = linear/matmul/conv(x, weight)
    However, for ANE, it requires the scale to be after the linear/matmul/conv, which is:
        weight = constexpr_lut_to_dense()
        unscaled_output = linear/matmul(x, weight)
        output = mul(unscaled_output, scale)
    This graph pass finds the lut with per-channel-scale and move the scale to be ANE-friendly.
    """

    _OPS_SUPPORT_MOVE_SCALE = {"linear", "matmul", "conv"}

    def apply(self, prog):
        @block_context_manager
        def apply_block(block: Block):
            # Iterate over a snapshot, because the reordering adds / removes ops in the block.
            for op in list(block.operations):
                for b in op.blocks:
                    apply_block(b)

                if op.op_type == "constexpr_lut_to_dense" and len(op.outputs[0].child_ops) == 1:
                    child_op = op.outputs[0].child_ops[0]
                    if child_op.op_type == "constexpr_blockwise_shift_scale":
                        # Can move the scale when the constexpr op is only used to scale the weight.
                        has_offset = child_op.offset is not None and child_op.offset.val.any()
                        if types.is_float(child_op.data.dtype) and not has_offset:
                            self._reorder_lut_per_channel_scale(block, op)

        for f in prog.functions.values():
            apply_block(f)

    def _reorder_lut_per_channel_scale(self, block: Block, lut_op: Operation):
        """
        Move the scale behind every consumer of ``lut_op -> scale_op``, if that is possible for
        all of them. The graph is left untouched otherwise.
        """
        # The original order is lut_op -> scale_op -> output_op.
        scale_op = lut_op.outputs[0].child_ops[0]

        # Snapshot the consumers, since they are removed from the block while being processed.
        output_ops = list(scale_op.outputs[0].child_ops)
        if not output_ops:
            # Nothing consumes the scaled weight, so there is nothing to move the scale behind.
            return

        # Only move the scale when all ops that consume this scale op support moving.
        if any(output_op.op_type not in self._OPS_SUPPORT_MOVE_SCALE for output_op in output_ops):
            return

        # Only the scale on output axis could be moved to get mathematically equivalent results.
        # This only depends on scale_op, so it is checked once instead of once per consumer.
        scale_val: np.ndarray = scale_op.scale.val
        output_axis = optimize_utils.select_input_output_channel_axis(scale_op)[1]
        if output_axis is None:
            return
        if output_axis < 0:
            output_axis += len(scale_val.shape)
        if any(
            axis != output_axis and dim_size != 1 for axis, dim_size in enumerate(scale_val.shape)
        ):
            return

        all_moved = True
        for output_op in output_ops:
            if self._help_move_scale(block, lut_op, scale_op, output_op):
                block.remove_ops([output_op])
            else:
                # The consumer was not rewired, so both it and the scale op it reads must stay.
                all_moved = False
        if all_moved:
            block.remove_ops([scale_op])

    @staticmethod
    def _rescale_bias(output_op: Operation, scale_val: np.ndarray, inputs: dict) -> None:
        """
        Store ``bias / scale`` into ``inputs["bias"]`` when ``output_op`` has a bias with known value.

        The scale is multiplied onto the whole output after the move (including the bias term),
        so the bias is divided by the scale beforehand to compensate.
        """
        bias = getattr(output_op, "bias", None)
        if not bias or bias.val is None:
            return
        original_bias = bias.val
        inputs["bias"] = (original_bias / np.squeeze(scale_val)).astype(original_bias.dtype)

    @staticmethod
    def _help_move_scale(
        block: Block, lut_op: Operation, scale_op: Operation, output_op: Operation
    ) -> bool:
        """
        Move the scale from `lut_op -> scale_op -> output_op` to `lut_op -> output_op -> mul`.

        Returns True when the uses of ``output_op``'s output were replaced by the new ``mul``
        output, and False when the scale could not be moved (nothing is changed in that case, and
        the caller must keep ``output_op``).

        Raises AssertionError if ``output_op`` is not a linear/matmul/conv op.
        """
        scale_val: np.ndarray = scale_op.scale.val
        inputs = output_op.inputs
        if output_op.op_type == "linear":
            scale_val = scale_val.T
            inputs["weight"] = lut_op.outputs[0]
            reorder_lut_per_channel_scale._rescale_bias(output_op, scale_val, inputs)
        elif output_op.op_type == "matmul":
            # Determine if the scaled weight is used by `x` or `y` in matmul.
            # The transpose flags are checked by truthiness rather than `is True`, because the
            # identity check is False for a true numpy bool.
            if output_op.y == scale_op.outputs[0]:
                if output_op.transpose_y.val:
                    scale_val = scale_val.T
                inputs["y"] = lut_op.outputs[0]
            else:
                if output_op.transpose_x.val:
                    scale_val = scale_val.T
                inputs["x"] = lut_op.outputs[0]
        else:
            if output_op.op_type != "conv":
                raise AssertionError(
                    "The scale could only be moved for linear/matmul/conv, "
                    f"but got {output_op.op_type}"
                )
            # The weight of conv has C_out at axis=0, but in output the C_out is at axis=1
            scale_val = np.squeeze(scale_val)
            if len(scale_val.shape) > 1:
                # The per-channel-scale should only have one axis with larger than 1 dim size.
                return False
            # Shape the scale as (1, C_out, 1, ..., 1) with the same rank as the conv output, so it
            # broadcasts along the channel axis no matter how many spatial dims the conv has.
            num_spatial_dims = len(output_op.outputs[0].shape) - 2
            scale_val = scale_val.reshape((1, scale_val.size) + (1,) * num_spatial_dims)
            inputs["weight"] = lut_op.outputs[0]
            reorder_lut_per_channel_scale._rescale_bias(output_op, scale_val, inputs)

        # Reconstruct the unscaled output which uses lut output as weight (skip the original scale).
        unscaled_output = getattr(mb, output_op.op_type)(**inputs, before_op=output_op)
        scaled_output = mb.mul(x=unscaled_output, y=scale_val, before_op=output_op)

        # Now the order is lut_op -> unscaled_output -> scaled_output.
        block.replace_uses_of_var_after_op(
            anchor_op=output_op,
            old_var=output_op.outputs[0],
            new_var=scaled_output,
            force_replace=True,  # Need to force replace because it involves replacing constexpr op.
        )
        return True


@register_pass(namespace="common")
class canonicalize_quantized_lut_pattern(AbstractGraphPass):
    """
    The quantized lut (e.g. each entry in the LUT is int8) could be represented by two patterns:
        Pattern 1:
            lut(int8) -> constexpr_blockwise_shift_scale -> lut(fp16) -> constexpr_lut_to_dense -> dense(fp16)
        Pattern 2:
            lut(int8) -> constexpr_lut_to_dense -> dense(int8) -> constexpr_blockwise_shift_scale -> dense(fp16)
    Those two patterns are mathematically equivalent when the quantization is per-tensor or per-channel.

    This graph pass makes sure we always use one specific pattern by re-ordering the ops.
    """

    _DEQUANT_FIRST = True  # First dequantize and then depalettize (use pattern 1).

    @staticmethod
    def get_order_to_reverse(expect_dequant_first: bool) -> Tuple[str, str]:
        """
        Get the wrong order pattern which will be canonicalized.

        Returns the op types ``(op1, op2)`` such that a chain ``op1 -> op2`` found in the graph
        gets reordered into ``op2 -> op1``.
        """
        if expect_dequant_first and not is_current_opset_version_compatible_with(
            AvailableTarget.iOS26
        ):
            # This issue was fixed in iOS26.
            # The LUT -> shift_scale op pattern will be reversed.
            wrong_order_op1 = "constexpr_lut_to_dense"
            wrong_order_op2 = "constexpr_blockwise_shift_scale"
        else:
            # The shift_scale -> LUT op pattern will be reversed.
            wrong_order_op1 = "constexpr_blockwise_shift_scale"
            wrong_order_op2 = "constexpr_lut_to_dense"

        return wrong_order_op1, wrong_order_op2

    def apply(self, prog):
        @block_context_manager
        def apply_block(block: Block):
            wrong_order_op1, wrong_order_op2 = self.get_order_to_reverse(self._DEQUANT_FIRST)

            # Iterate over a snapshot, because the reordering adds / removes ops in the block.
            for op in list(block.operations):
                for b in op.blocks:
                    apply_block(b)
                # Only handle op1 whose single consumer is op2.
                if op.op_type == wrong_order_op1 and len(op.outputs[0].child_ops) == 1:
                    if op.outputs[0].child_ops[0].op_type == wrong_order_op2:
                        self._reorder_quant_lut(block, op)

        for f in prog.functions.values():
            apply_block(f)

    def _reorder_quant_lut(self, block: Block, old_op1: Operation):
        """
        Original order is op1 -> op2 -> output_op, and after reorder it becomes op2 -> op1 -> output_op.
        Here op1 and op2 corresponds to either lut op or quant op, depending on the desired order.

        The graph is left untouched when op1 is the lut op and its lut needs to be repeated to
        match the scale, but is not const or cannot be repeated evenly.

        Raises AssertionError if op1 is the quant op and the last two dims of its scale are not
        both 1.
        """
        old_op2 = old_op1.outputs[0].child_ops[0]
        # If the old op has some meaningful info in the name (such as "conv1.weight"), we need to keep it.
        new_op1_name = None if old_op1.op_type in old_op1.name else old_op1.name
        new_op2_name = None if old_op2.op_type in old_op2.name else old_op2.name

        if old_op1.op_type == "constexpr_blockwise_shift_scale":
            # The old_op1 is dequant op and old_op2 is a lut op.
            # The scale and offset from old_op1 is for lut, so the rank need to be adjusted.
            if old_op1.scale.shape[-2:] != (1, 1):
                raise AssertionError(
                    "The quantization on lut must be per-tensor, so last two dims in `scale` should "
                    f"both be 1, but got scale with shape {old_op1.scale.shape}."
                )
            # Drop the two trailing dims (verified to be 1 above), which only exist on the lut, so
            # scale and offset get the rank of the depalettized weight. This is the inverse of the
            # `+ (1, 1)` done in the other branch.
            new_scale_shape = old_op1.scale.shape[:-2]
            scale = old_op1.scale.val.reshape(new_scale_shape)
            offset = old_op1.offset
            if offset is not None and offset.val is not None:
                offset = old_op1.offset.val.reshape(new_scale_shape)

            new_op1_args = {"indices": old_op2.indices, "lut": old_op1.data, "before_op": old_op2}
            if new_op1_name is not None:
                new_op1_args["name"] = new_op1_name
            new_op1 = mb.constexpr_lut_to_dense(**new_op1_args)

            new_op2_args = {"data": new_op1, "scale": scale, "offset": offset, "before_op": old_op2}
            if new_op2_name is not None:
                new_op2_args["name"] = new_op2_name
            new_op2 = mb.constexpr_blockwise_shift_scale(**new_op2_args)
        else:
            # The old_op1 is lut op and old_op2 is a dequant op.
            # The scale and offset from old_op2 is for depalettized weight, so the rank need to be adjusted to match
            # the lut's rank.
            new_scale_shape = old_op2.scale.shape + (1, 1)
            scale = old_op2.scale.val.reshape(new_scale_shape)
            offset = old_op2.offset
            if offset is not None and offset.val is not None:
                offset = old_op2.offset.val.reshape(new_scale_shape)

            lut = old_op1.lut
            if any(shape != 1 for shape in new_scale_shape):
                # The lut need to be repeated when necessary. For example, in per-channel-scale, the lut has shape
                # [16, 1, 16, 1], indices has shape [32, 1], and scale has shape [32, 1]. It means every two rows in
                # the weight share a lut, and it's impossible to apply 32 scales to 16 lut tables. So we need to repeat
                # the lut to become [32, 1, 16, 1], and then apply those 32 scales to each row.
                lut = old_op1.lut.val
                if lut is None:
                    return  # Cannot handle the reordering when the lut is not const.
                for axis, (scale_shape, lut_shape) in enumerate(zip(new_scale_shape, lut.shape)):
                    if scale_shape > lut_shape:
                        if scale_shape % lut_shape != 0:
                            return  # Skip when lut's shape cannot be repeated to match scale's shape.
                        lut = np.repeat(lut, scale_shape // lut_shape, axis=axis)

            new_op1_args = {"data": lut, "scale": scale, "offset": offset, "before_op": old_op1}
            if new_op1_name is not None:
                new_op1_args["name"] = new_op1_name
            new_op1 = mb.constexpr_blockwise_shift_scale(**new_op1_args)

            new_op2_args = {"indices": old_op1.indices, "lut": new_op1, "before_op": old_op1}
            if new_op2_name is not None:
                new_op2_args["name"] = new_op2_name
            new_op2 = mb.constexpr_lut_to_dense(**new_op2_args)

        block.replace_uses_of_var_after_op(
            anchor_op=old_op2,
            old_var=old_op2.outputs[0],
            new_var=new_op2,
            force_replace=True,  # Need to force replace because it involves replacing constexpr op.
        )
        block.remove_ops([old_op1, old_op2])
